# ------------------ Dependencies ------------------
if (!requireNamespace("igraph", quietly = TRUE)) install.packages("igraph")
if (!requireNamespace("pcalg", quietly = TRUE)) install.packages("pcalg")
if (!requireNamespace("Rgraphviz", quietly = TRUE)) {
  install.packages("BiocManager")
  BiocManager::install("Rgraphviz")
}
if (!requireNamespace("bnlearn", quietly = TRUE)) install.packages("bnlearn")

library(igraph)
library(pcalg)
library(Rgraphviz)
library(bnlearn)

# ------------------ Utility Functions ------------------
# Abort unless every entry of `nodes` is a vertex name of `graph`.
#
# @param graph Graph whose `V(graph)$name` lists the known vertex names.
# @param nodes Vector of names to look up.
# @return NULL (invisibly) when all names are present; otherwise signals an
#   error.
validate_vertices <- function(graph, nodes) {
  known_names <- V(graph)$name
  absent <- setdiff(nodes, known_names)
  if (length(absent) > 0L) {
    stop("Some nodes are not present in the graph.")
  }
}

# Return the unique names of the out-neighbours ("children") of `nodes`.
#
# NOTE: get_children() is assigned a second time further down in this file.
# When the script is sourced top to bottom, that later assignment replaces
# this one. This version is written to behave the same way: names that are not
# vertices of `graph` are ignored instead of raising an error.
#
# @param graph igraph graph with a `name` vertex attribute.
# @param nodes Character vector of vertex names.
# @return Character vector of child names (character(0) when there are none).
get_children <- function(graph, nodes) {
  vertex_names <- V(graph)$name
  node_ids <- match(nodes, vertex_names)
  node_ids <- node_ids[!is.na(node_ids)]
  if (length(node_ids) == 0L) return(character(0))
  # One vectorised query instead of growing a result vector inside a loop.
  child_ids <- unlist(adjacent_vertices(graph, node_ids, mode = "out"), use.names = FALSE)
  unique(vertex_names[child_ids])
}

# Collect the directed children of `selected_nodes` in a CPDAG, excluding the
# selected nodes themselves.
#
# A vertex y counts as a directed child of x when the arc x -> y exists and the
# arc y -> x does not. This assumes undirected CPDAG edges are stored as a pair
# of opposite arcs (TODO confirm for the graph source in use).
#
# Vertices are compared by NAME throughout. Previously, numeric vertex IDs
# were set-differenced against names, which only worked when names happened to
# equal IDs.
#
# @param cpdag_igraph Directed igraph graph with a `name` vertex attribute.
# @param selected_nodes Character vector of vertex names; all must exist
#   (checked by validate_vertices(), which errors otherwise).
# @return list(combined_children = unique child names not in selected_nodes,
#   is_empty = TRUE when there are none).
check_combined_children <- function(cpdag_igraph, selected_nodes) {
  validate_vertices(cpdag_igraph, selected_nodes)

  per_node <- lapply(selected_nodes, function(node) {
    out_names <- as_ids(neighbors(cpdag_igraph, node, mode = "out"))
    in_names <- as_ids(neighbors(cpdag_igraph, node, mode = "in"))
    # Bidirected pairs (undirected edges) are not directed children.
    setdiff(out_names, in_names)
  })

  combined_children <- setdiff(as.character(unlist(per_node, use.names = FALSE)), selected_nodes)
  list(combined_children = combined_children, is_empty = length(combined_children) == 0L)
}

# TRUE when every pair of distinct vertices of `g` is adjacent.
#
# The graph is first reduced to a simple undirected graph: arc directions are
# collapsed, and parallel edges and self-loops are removed. After that, the
# edge count equals the number of distinct adjacent pairs. Counting raw edges
# could report a non-complete graph as complete when it contains loops or
# duplicate edges (e.g. edges a-b, a-b, a-c on three vertices).
#
# @param g igraph graph.
# @return Single logical. Graphs with 0 or 1 vertices are complete.
is_complete_graph <- function(g) {
  g_simple <- igraph::simplify(as_undirected(g, mode = "collapse"))
  num_vertices <- vcount(g_simple)
  ecount(g_simple) == num_vertices * (num_vertices - 1) / 2
}

# For the part of M inside one chain component, check that each connected
# piece of it has a complete neighbourhood.
#
# M is restricted to `chain_component_nodes`, and the connected components of
# that restriction are taken in `g_undirected`. For each such component, the
# set of vertices adjacent to it (excluding the component itself) must induce
# a complete graph.
#
# @param chain_component_nodes Character vector of vertex names.
# @param M Character vector of vertex names.
# @param g_undirected Undirected igraph graph with a `name` vertex attribute.
# @return TRUE if every component passes (or M misses the chain component),
#   FALSE otherwise.
check_connected_component_neighbors_complete <- function(chain_component_nodes, M, g_undirected) {
  M_in_chain <- intersect(M, chain_component_nodes)
  if (length(M_in_chain) == 0L) return(TRUE)

  subgraph_M <- induced_subgraph(g_undirected, M_in_chain)
  M_components <- components(subgraph_M)
  sub_names <- V(subgraph_M)$name
  all_names <- V(g_undirected)$name

  for (i in seq_len(M_components$no)) {
    component_nodes <- sub_names[M_components$membership == i]
    # neighbors() queries a single vertex; adjacent_vertices() returns the
    # neighbourhood of every vertex in the component.
    adj_ids <- unlist(adjacent_vertices(g_undirected, component_nodes, mode = "all"), use.names = FALSE)
    # Compare by name, not by numeric vertex ID.
    neighbors_nodes <- setdiff(all_names[adj_ids], component_nodes)

    if (length(neighbors_nodes) > 0L) {
      neighbor_subgraph <- induced_subgraph(g_undirected, neighbors_nodes)
      if (!is_complete_graph(neighbor_subgraph)) return(FALSE)
    }
  }
  TRUE
}


# Union of `nodeset` with the Markov blanket (bnlearn mb()) of each member.
#
# Despite the name, nothing is cached: mb() is called afresh for every node on
# every call. The nodes of `nodeset` come first in the result, followed by the
# blanket members in first-seen order, without duplicates.
#
# @param dag Network object accepted by mb().
# @param nodeset Character vector of node names.
# @return Character vector of unique node names.
get_markov_boundary_cache <- function(dag, nodeset) {
  blanket_members <- unlist(lapply(nodeset, function(v) mb(dag, v)), use.names = FALSE)
  unique(c(nodeset, blanket_members))
}

# Removability check for a single node.
#
# Returns TRUE straight away when `boundary` has at most two nodes, or when
# `node` has no children in `G_igraph`. Otherwise every unordered pair of
# boundary nodes must be adjacent according to `adj_matrix_cache` (non-zero
# entry), except pairs in which BOTH nodes are parents of `node`.
#
# @param node Vertex name under test.
# @param boundary Character vector of vertex names to check pairwise; all must
#   appear in the dimnames of `adj_matrix_cache`.
# @param G_igraph Directed graph providing the parents/children of `node`.
# @param adj_matrix_cache Adjacency matrix with vertex-name dimnames.
# @return Single logical.
is_c_removable <- function(node, boundary, G_igraph, adj_matrix_cache) {
  vertex_names <- V(G_igraph)$name
  idx <- which(vertex_names == node)
  parent_names <- vertex_names[neighbors(G_igraph, idx, mode = "in")]
  child_names <- vertex_names[neighbors(G_igraph, idx, mode = "out")]

  if (length(boundary) <= 2L || length(child_names) == 0L) {
    return(TRUE)
  }

  pair_mat <- combn(boundary, 2)
  first <- pair_mat[1, ]
  second <- pair_mat[2, ]

  # Pairs made up solely of parents of `node` are exempt.
  needs_check <- !(first %in% parent_names & second %in% parent_names)
  if (!any(needs_check)) {
    return(TRUE)
  }

  # Character-matrix indexing looks the pairs up by dimnames.
  edge_flags <- adj_matrix_cache[cbind(first[needs_check], second[needs_check])]
  all(edge_flags != 0)
}

# Return the unique names of the out-neighbours ("children") of `nodes` in
# `graph`.
#
# Names in `nodes` that are not vertices of `graph` are dropped before
# querying. Previously, `match(..., nomatch = 0L)` passed vertex ID 0 (invalid
# in igraph) through to adjacent_vertices().
#
# @param graph igraph graph with a `name` vertex attribute.
# @param nodes Character vector of vertex names.
# @return Character vector of child names (character(0) when there are none).
get_children <- function(graph, nodes) {
  vertex_names <- V(graph)$name
  node_ids <- match(nodes, vertex_names)
  node_ids <- node_ids[!is.na(node_ids)]
  if (length(node_ids) == 0L) return(character(0))
  adj_nodes <- adjacent_vertices(graph, node_ids, mode = "out")
  unique(vertex_names[unlist(adj_nodes, use.names = FALSE)])
}

# Iterative removal check for the members of M inside ONE connected component.
#
# Repeatedly scans the remaining members of M. Any member that passes
# is_c_removable() is deleted from `comp` and recorded in B. The scan stops
# with 0L when a full pass removes nothing. Otherwise it ends with 1L when the
# remaining members have no children in the reduced component, or when none of
# them lies in another remaining member's Markov blanket. Markov blankets are
# always taken from the full network `g_bn`, not from the reduced component.
#
# @param comp igraph (sub)graph of one component, with vertex names.
# @param M_names Character vector of all names in M.
# @param g_bn bnlearn network used for Markov-blanket queries.
# @return 1L or 0L.
check_component_removable <- function(comp, M_names, g_bn) {
  comp_nodes <- V(comp)$name
  M_current <- intersect(M_names, comp_nodes)

  # A component with no member of M imposes no constraint.
  if (length(M_current) == 0) return(1L)

  # Undirected adjacency of the component as it is on entry. Deleting vertices
  # later does not change adjacency among the survivors, so this stays valid.
  ug_component <- as_undirected(comp, mode = "collapse")
  adj_matrix_cache <- as_adjacency_matrix(ug_component, sparse = FALSE)
  rownames(adj_matrix_cache) <- colnames(adj_matrix_cache) <- V(ug_component)$name

  removable_flag <- 1L
  B <- character(0)
  M_children <- get_children(comp, M_current)
  not_removed_nodes <- character(0)

  while (length(M_children) > 0) {
    removable_found <- FALSE

    # `for` evaluates M_current once, so edits to M_current inside the loop do
    # not change which nodes are visited in this pass.
    for (x in M_current) {
      if (!x %in% V(comp)$name) next
      boundary <- intersect(get_markov_boundary_cache(g_bn, x), V(comp)$name)

      if (is_c_removable(x, boundary, comp, adj_matrix_cache)) {
        B <- union(B, x)
        M_current <- setdiff(M_current, x)
        comp <- delete_vertices(comp, x)
        removable_found <- TRUE

        # Moves previously non-removable nodes to the end of M_current. This
        # may re-add nodes already in B; the setdiff(M_current, B) after the
        # pass removes them again.
        if (length(intersect(boundary, not_removed_nodes)) == 0) {
          M_current <- unique(c(setdiff(M_current, not_removed_nodes), not_removed_nodes))
        }
      } else {
        not_removed_nodes <- c(not_removed_nodes, x)
      }
    }

    if (!removable_found) {
      removable_flag <- 0L
      break
    }

    M_current <- setdiff(M_current, B)
    markov_boundary_new <- unique(unlist(lapply(M_current, function(x) {
      setdiff(get_markov_boundary_cache(g_bn, x), B)
    })))

    M0 <- intersect(M_current, markov_boundary_new)
    M1 <- get_children(comp, M_current)

    if (length(M1) == 0) {
      break
    } else if (length(M0) > 0) {
      # Put members that lie in another member's blanket first.
      M_current <- union(M0, setdiff(M_current, M0))
    } else {
      break
    }

    M_children <- get_children(comp, M_current)
  }

  removable_flag
}

# Decide whether the node set M_names passes the removability check, treating
# each connected component of the relevant subgraph independently.
#
# The relevant subgraph is induced in `g_igraph` by M_names together with
# their Markov blankets (from `g_bn`). It is split with decompose(), and each
# component is checked by check_component_removable(). The set passes only if
# every component does.
#
# @param M_names Character vector of vertex names.
# @param g_bn bnlearn network used for Markov-blanket queries.
# @param g_igraph igraph version of the same DAG, with vertex names.
# @param mc_cores Worker processes for parallel::mclapply(). Forking is not
#   available on Windows, where mclapply() rejects mc.cores > 1, so the
#   default there is 1 (sequential). Elsewhere it matches mclapply()'s own
#   default.
# @return 1L if every component passes, 0L otherwise. Errors if any component
#   check fails.
is_c_removable_set <- function(M_names, g_bn, g_igraph,
                               mc_cores = if (.Platform$OS.type == "windows") 1L else getOption("mc.cores", 2L)) {
  markov_boundary <- get_markov_boundary_cache(g_bn, M_names)
  sub_nodes <- intersect(c(M_names, markov_boundary), V(g_igraph)$name)
  g_sub <- induced_subgraph(g_igraph, sub_nodes)

  comp_list <- decompose(g_sub)

  results <- parallel::mclapply(
    comp_list,
    check_component_removable,
    M_names = M_names,
    g_bn = g_bn,
    mc.cores = mc_cores
  )

  # mclapply() does not re-raise worker errors: a failed worker yields a
  # "try-error" value (or NULL). Without this check such a value would compare
  # unequal to 0 below and silently count as "removable".
  ok <- vapply(results, function(r) is.numeric(r) && length(r) == 1L, logical(1))
  if (!all(ok)) {
    bad <- results[[which(!ok)[1L]]]
    detail <- if (inherits(bad, "try-error")) trimws(as.character(bad)) else "unexpected worker result"
    stop("component check failed: ", detail, call. = FALSE)
  }

  if (any(unlist(results, use.names = FALSE) == 0L)) 0L else 1L
}

# ------------------ Main Function for Single M ------------------

# Benchmark the two removability checks on random DAGs.
#
# For each of `runs` repetitions, this function:
#   - draws a DAG with randomDAG(n, prob = p);
#   - samples M_size of its nodes as M;
#   - times (1) check_combined_children() on the CPDAG, followed by the
#     chain-component neighbourhood test when M has no directed children;
#   - times (2) is_c_removable_set() on the DAG.
# Per-run timings are written as CSV into `save_path`.
#
# @param n Number of nodes per random DAG.
# @param p Edge probability passed to randomDAG().
# @param M_size Number of nodes sampled into M (must not exceed n).
# @param runs Number of repetitions.
# @param save_path Output directory; created if missing.
# @return Named numeric vector c(CPDAG_Time, Markov_Boundary_Time) holding the
#   mean seconds per run.
measure_performance <- function(n, p, M_size, runs, save_path) {
  if (!dir.exists(save_path)) {
    dir.create(save_path, recursive = TRUE)
  }

  results <- matrix(0, nrow = runs, ncol = 4)
  colnames(results) <- c("CPDAG_Time", "Markov_Boundary_Time", "n", "p")

  for (i in seq_len(runs)) {
    cat(M_size, sprintf("Run %d/%d\n", i, runs))
    set.seed(123 + i)
    dag <- randomDAG(n, prob = p)
    graphNEL_dag <- as(dag, "graphNEL")
    g_bn <- as.bn(graphNEL_dag)
    node_names <- nodes(g_bn)
    g_igraph <- as.igraph(g_bn)
    cpdag <- dag2cpdag(dag)
    cpdag_igraph <- graph_from_graphnel(cpdag)

    # Re-seed so the choice of M does not depend on how many random numbers
    # randomDAG() consumed.
    set.seed(123 + i)
    M <- sample(seq_len(n), M_size)
    M_names <- node_names[M]

    start1 <- Sys.time()
    result <- check_combined_children(cpdag_igraph, M_names)

    if (result$is_empty) {
      adj <- as_adjacency_matrix(cpdag_igraph)
      # Keep only arcs present in both directions (undirected CPDAG edges).
      undirected <- adj * t(adj)
      g_undirected <- graph_from_adjacency_matrix(undirected, mode = "undirected")
      comp <- components(g_undirected)
      # Evaluated for timing only; the per-component results are not recorded.
      vapply(seq_len(comp$no), function(j) {
        chain_nodes <- V(g_undirected)$name[comp$membership == j]
        check_connected_component_neighbors_complete(chain_nodes, M_names, g_undirected)
      }, logical(1))
    }
    results[i, 1] <- as.numeric(difftime(Sys.time(), start1, units = "secs"))

    start2 <- Sys.time()
    # Evaluated for timing only; the returned 0L/1L flag is not recorded.
    is_c_removable_set(M_names, g_bn, g_igraph)
    results[i, 2] <- as.numeric(difftime(Sys.time(), start2, units = "secs"))
    results[i, 3] <- n
    results[i, 4] <- p
  }

  # %g keeps p distinct in the file name. %.2f rounded it to two decimals, so
  # e.g. p = 0.006 and p = 0.01 wrote to the same file.
  csv_filename <- sprintf("performance_results_n%d_p%g_M%d.csv", n, p, M_size)
  write.csv(results, file.path(save_path, csv_filename), row.names = FALSE)

  # drop = FALSE keeps a matrix when runs == 1, so colMeans() still works.
  colMeans(results[, 1:2, drop = FALSE])
}

# --------------------- Benchmark driver ---------------------
# Runs measure_performance() for M sizes 50, 100, ..., 450. Each run uses
# 500-node random DAGs with edge probability 0.005, 30 runs per size. The
# returned mean timings are stacked into `performance_data`, one row per M
# size. This section only computes timings; no plots are produced here.
# NOTE(review): "path_to_save" looks like a placeholder output directory —
# confirm before running.
set.seed(123)
M_sizes <- seq(50, 450, by = 50)
performance_data <- t(vapply(
  M_sizes,
  function(ms) measure_performance(500, 0.005, ms, 30, "path_to_save"),
  numeric(2)
))
